import dgl
from dgl.dataloading import Sampler

import torch

from data.load_dataset import load_ogb, load_partitioned_graphs, load_complete_graphs

import numpy as np
import unittest

class TestRearrangeActivations(unittest.TestCase):
    """Unit tests for ``rearrange_activations``."""

    def _assert_indices(self, local_src_nid, global_src_nid, activation_nid,
                        expected_local_indices, expected_global_indices):
        """Run ``rearrange_activations`` and compare both outputs exactly."""
        local_indices, global_indices = rearrange_activations(
            local_src_nid, global_src_nid, activation_nid)

        # The results are used to index activation tensors, so they must be int64.
        self.assertEqual(local_indices.dtype, torch.long)
        self.assertEqual(global_indices.dtype, torch.long)

        # torch.equal also compares shapes; torch.all(a == b) broadcasts and
        # would accept e.g. a length-1 result that happens to match every entry.
        self.assertTrue(torch.equal(local_indices, expected_local_indices))
        self.assertTrue(torch.equal(global_indices, expected_global_indices))

    def test_rearrange_activations(self):
        # Test case 1: local and global IDs interleaved in activation_nid.
        self._assert_indices(
            torch.tensor([10, 20, 30, 40]),
            torch.tensor([50, 60, 70, 80]),
            torch.tensor([40, 30, 60, 80, 20, 70, 10, 50]),
            torch.tensor([6, 4, 1, 0]),
            torch.tensor([7, 2, 5, 3]),
        )

        # Test case 2: smaller input.
        self._assert_indices(
            torch.tensor([20, 40]),
            torch.tensor([60, 80]),
            torch.tensor([40, 60, 80, 20]),
            torch.tensor([3, 0]),
            torch.tensor([1, 2]),
        )

    def test_empty_local_nodes(self):
        # No local nodes: the local result must be an empty int64 tensor.
        self._assert_indices(
            torch.tensor([], dtype=torch.long),
            torch.tensor([7, 3]),
            torch.tensor([3, 7]),
            torch.tensor([], dtype=torch.long),
            torch.tensor([1, 0]),
        )

def _first_positions(sorted_nid, order, query):
    """Map each ID in ``query`` to its first index in the unsorted tensor behind ``sorted_nid``.

    ``sorted_nid, order`` must come from a *stable* sort of the original tensor.
    """
    if query.numel() == 0:
        return torch.empty(0, dtype=torch.long, device=order.device)
    if sorted_nid.numel() == 0:
        raise ValueError("activation_nid is empty but node IDs were requested")

    # Leftmost insertion point; combined with the stable sort this is the
    # earliest occurrence in the original (unsorted) order.
    pos = torch.searchsorted(sorted_nid, query.contiguous())
    # searchsorted returns len(sorted_nid) for IDs larger than every entry;
    # clamp so the gather stays in bounds, then verify matches explicitly.
    pos = pos.clamp(max=sorted_nid.numel() - 1)
    found = sorted_nid[pos] == query
    if not bool(found.all()):
        missing = query[~found][:10].tolist()
        raise ValueError(f"node IDs not present in activation_nid (first few): {missing}")
    return order[pos]


def rearrange_activations(local_src_nid, global_src_nid, activation_nid):
    """
    Locate local and global source node IDs inside ``activation_nid``.

    For every ID in ``local_src_nid`` (resp. ``global_src_nid``) this returns the
    position of that ID in ``activation_nid``. Indexing rows laid out in
    ``activation_nid`` order with the result therefore reorders them into
    ``local_src_nid`` (resp. ``global_src_nid``) order. If an ID occurs more than
    once in ``activation_nid``, its first position is used.

    Runs in O((n + m) log m) via a stable sort and ``torch.searchsorted``
    instead of one linear scan per ID.

    Args:
        local_src_nid (torch.Tensor): 1-D local source node IDs.
        global_src_nid (torch.Tensor): 1-D global source node IDs.
        activation_nid (torch.Tensor): 1-D node IDs in the order in which the
            activation rows are stored.

    Returns:
        tuple: ``(local_indices, global_indices)``, two int64 tensors with the
        same lengths as ``local_src_nid`` and ``global_src_nid`` respectively.

    Raises:
        ValueError: If ``len(local_src_nid) + len(global_src_nid)`` differs from
            the number of entries in ``activation_nid``, or if a requested ID is
            missing from ``activation_nid``.
    """
    activation_nid = torch.as_tensor(activation_nid).reshape(-1)
    # searchsorted needs the queries on the same device and with the same dtype
    # as the sorted sequence.
    local_src_nid = torch.as_tensor(
        local_src_nid, dtype=activation_nid.dtype, device=activation_nid.device).reshape(-1)
    global_src_nid = torch.as_tensor(
        global_src_nid, dtype=activation_nid.dtype, device=activation_nid.device).reshape(-1)

    if local_src_nid.numel() + global_src_nid.numel() != activation_nid.numel():
        raise ValueError(
            f"expected {activation_nid.numel()} node IDs in total, got "
            f"{local_src_nid.numel()} local + {global_src_nid.numel()} global")

    sorted_nid, order = torch.sort(activation_nid, stable=True)

    local_indices = _first_positions(sorted_nid, order, local_src_nid).long()
    global_indices = _first_positions(sorted_nid, order, global_src_nid).long()
    return local_indices, global_indices

class LocalNeighborSampler(Sampler):
    """Multi-layer neighbour sampler that only looks at one (local) graph.

    Each layer is returned as a ``(block, dummy_remote_block)`` pair, so the
    output has the same layout as ``GlobalNeighborSampler``. The remote half is
    built from an empty seed set and is asserted to contain no edges.

    Args:
        fanouts (list[int]): Neighbours sampled (with replacement) per seed,
            one entry per layer; ``fanouts[0]`` is used around the input seeds.
    """

    def __init__(self, fanouts):
        super().__init__()
        self.fanouts = fanouts

    def sample(self, local_graph, seeds):
        """Sample one block pair per fanout, starting from ``seeds``.

        Args:
            local_graph: DGL graph to sample from; ``seeds`` are its node IDs.
            seeds (torch.Tensor): Seed node IDs (converted to a CPU int64 tensor).

        Returns:
            tuple: ``(blocks, transform_indices)``.
                ``blocks`` holds ``(block, dummy_remote_block)`` pairs; the last
                entry is the block whose destination nodes are ``seeds``.
                ``transform_indices`` has ``len(fanouts) - 1`` pairs
                ``(arange(num_src_nodes), empty int64 tensor)``, ordered like
                ``blocks`` without its first entry.
        """
        frontier_seeds = seeds.type(torch.LongTensor)
        no_seeds = torch.LongTensor([])
        final_layer = len(self.fanouts) - 1

        sampled_pairs = []
        index_pairs = []
        for layer, fanout in enumerate(self.fanouts):
            # Sample ``fanout`` neighbours per seed and compact into a bipartite block.
            sampled = dgl.sampling.sample_neighbors(
                local_graph, frontier_seeds, fanout, replace=True)
            local_block = dgl.to_block(sampled, frontier_seeds)

            # Placeholder "remote" block from no seeds, so the result mirrors
            # the (local, remote) pairs produced by GlobalNeighborSampler.
            empty_frontier = dgl.sampling.sample_neighbors(
                local_graph, no_seeds, fanout, replace=True)
            empty_block = dgl.to_block(empty_frontier, no_seeds)
            assert empty_block.number_of_edges() == 0

            sampled_pairs.append((local_block, empty_block))

            if layer == final_layer:
                continue

            # This block's source nodes become the next layer's seeds.
            frontier_seeds = local_block.srcdata[dgl.NID]

            # With no remote part, activations map 1:1 onto the source nodes:
            # local indices are simply 0..len-1, remote indices are empty.
            identity_idx = torch.arange(len(frontier_seeds))
            no_idx = torch.empty(0, dtype=torch.long)
            index_pairs.append((identity_idx, no_idx))

        # Layers were produced output-first; callers expect input-first order.
        return sampled_pairs[::-1], index_pairs[::-1]
    
class GlobalNeighborSampler(Sampler):
    """Multi-layer neighbour sampler that samples local and remote seeds separately.

    At every layer the seeds are split by ``local_mask`` into local and remote
    sets. Each set is sampled on its own, which yields one ``(local_block,
    remote_block)`` pair per layer, along with index tensors that map the next
    layer's concatenated ``[local, remote]`` activations back onto each block's
    source nodes.

    Args:
        fanouts (list[int]): Neighbours sampled (with replacement) per seed,
            one entry per layer; ``fanouts[0]`` is used around the input seeds.
        local_mask (torch.Tensor): Boolean mask indexed by node ID of the
            sampled graph; ``True`` marks a local node.
    """

    def __init__(self, fanouts, local_mask):
        super().__init__()
        self.fanouts = fanouts
        self.local_mask = local_mask

    def _split_by_locality(self, nids):
        """Split ``nids`` into ``(local_nids, remote_nids)`` using ``self.local_mask``."""
        is_local = self.local_mask[nids]
        return nids[is_local], nids[~is_local]

    def sample(self, complete_graph, seeds):
        """Sample one ``(local_block, remote_block)`` pair per fanout.

        Args:
            complete_graph: DGL graph to sample from; ``seeds`` are its node IDs.
            seeds (torch.Tensor): Seed node IDs (converted to a CPU int64 tensor).

        Returns:
            tuple: ``(blocks, transform_indices)``.
                ``blocks`` holds ``(local_block, remote_block)`` pairs; the
                last entry is the pair whose destination nodes are ``seeds``.
                ``transform_indices`` has ``len(fanouts) - 1`` pairs from
                ``rearrange_activations``, ordered like ``blocks`` without its
                first entry.
        """
        local_seeds, remote_seeds = self._split_by_locality(seeds.type(torch.LongTensor))
        final_layer = len(self.fanouts) - 1

        pairs = []
        index_pairs = []
        for layer, fanout in enumerate(self.fanouts):
            # Local seeds: sample neighbours, then compact into a bipartite block.
            local_frontier = dgl.sampling.sample_neighbors(
                complete_graph, local_seeds, fanout, replace=True)
            local_block = dgl.to_block(local_frontier, local_seeds)

            # Remote seeds: same procedure, kept in a separate block.
            remote_frontier = dgl.sampling.sample_neighbors(
                complete_graph, remote_seeds, fanout, replace=True)
            remote_block = dgl.to_block(remote_frontier, remote_seeds)

            pairs.append((local_block, remote_block))

            if layer == final_layer:
                break

            # Next layer's seeds are the union (by concatenation) of both
            # blocks' source nodes, re-split by locality.
            # NOTE(review): a node that is a neighbour of both a local and a
            # remote seed appears in both srcdata tensors, so the concatenation
            # may hold duplicate IDs that reach dgl.to_block as dst nodes; confirm
            # this is intended.
            local_src = local_block.srcdata[dgl.NID]
            remote_src = remote_block.srcdata[dgl.NID]
            local_seeds, remote_seeds = self._split_by_locality(
                torch.cat([local_src, remote_src]))

            # The next layer emits activations in [local_seeds, remote_seeds]
            # order; map them onto this layer's local/remote source nodes.
            activation_order = torch.cat([local_seeds, remote_seeds])
            local_idx, remote_idx = rearrange_activations(local_src, remote_src, activation_order)
            index_pairs.append((local_idx, remote_idx))

        # Layers were produced output-first; callers expect input-first order.
        return pairs[::-1], index_pairs[::-1]

def _report_sampled_blocks(name, sampler, graph, seeds, num_layers):
    """Sample ``seeds`` with ``sampler``, print the result and sanity-check shapes.

    Args:
        name (str): Label printed before the report.
        sampler: Object with ``sample(graph, seeds) -> (blocks, transform_indices)``
            laid out like the samplers in this file.
        graph: DGL graph to sample from; ``seeds`` must be node IDs of this graph.
        seeds (torch.Tensor): Seed node IDs.
        num_layers (int): Expected number of layers (``len(fanouts)``).
    """
    blocks, transform_indices = sampler.sample(graph, seeds)
    assert len(blocks) == num_layers
    assert len(transform_indices) == num_layers - 1

    print(f"===== {name} =====")

    # Print the src/dst nodes of the blocks
    for i, (local_block, global_block) in enumerate(blocks):
        print(f"Block {i}")
        print(f"Local block: {local_block.srcdata[dgl.NID]} -> {local_block.dstdata[dgl.NID]}")
        print(f"Global block: {global_block.srcdata[dgl.NID]} -> {global_block.dstdata[dgl.NID]}")
        print()

    # Print the transform indices
    for i, (local_indices, global_indices) in enumerate(transform_indices):
        print(f"Transform indices {i}")
        print(f"Local indices: {local_indices}")
        print(f"Global indices: {global_indices}")
        print(f"Length of local indices: {len(local_indices)}")
        print(f"Length of global indices: {len(global_indices)}")
        print()

    # Both lists are filled output-layer-first and then reversed, so
    # transform_indices[k] is built from the source nodes of blocks[k + 1].
    for k, (local_indices, global_indices) in enumerate(transform_indices):
        local_block, global_block = blocks[k + 1]
        assert len(local_indices) == len(local_block.srcdata[dgl.NID])
        assert len(global_indices) == len(global_block.srcdata[dgl.NID])


if __name__ == '__main__':
    # To run the unit tests instead of this smoke test, call unittest.main() here.

    # Load complete graph
    complete_graph = load_complete_graphs('ogbn-arxiv')

    # Load partitioned graph
    path_to_partitioned_dataset = 'partitioned_dataset/ogbn-arxiv_metis_5.bin'
    partitioned_graphs = load_partitioned_graphs(path_to_partitioned_dataset)
    local_graph = partitioned_graphs[0]

    # Training nodes as partition-local IDs. nonzero(as_tuple=True)[0] stays
    # 1-D even when there is a single training node (unlike .squeeze()).
    train_nid_in_local_graph = local_graph.ndata['train_mask'].nonzero(as_tuple=True)[0]

    # The same training nodes as IDs in the complete graph.
    local_training_node_nid = local_graph.ndata['_ID'][train_nid_in_local_graph]

    # Get all the local nodes' IDs in the complete graph
    local_node_id = local_graph.ndata['_ID']

    # Build the local mask (indexed by complete-graph node ID)
    local_mask = torch.zeros(complete_graph.number_of_nodes(), dtype=torch.bool)
    local_mask[local_node_id] = True

    fanouts = [5, 10, 15]
    num_seeds = 1000
    local_sampler = LocalNeighborSampler(fanouts)
    global_sampler = GlobalNeighborSampler(fanouts, local_mask)

    # The local sampler walks the partition graph, so it needs partition-local
    # seed IDs (the complete-graph IDs may be out of range for local_graph).
    _report_sampled_blocks(
        "LocalNeighborSampler", local_sampler, local_graph,
        train_nid_in_local_graph[:num_seeds], len(fanouts))

    # The global sampler walks the complete graph and indexes local_mask,
    # so it needs complete-graph seed IDs.
    _report_sampled_blocks(
        "GlobalNeighborSampler", global_sampler, complete_graph,
        local_training_node_nid[:num_seeds], len(fanouts))